package main

import (
	"bytes"
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/sirupsen/logrus"
	"github.com/spf13/cobra"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/types"
	"sigs.k8s.io/controller-runtime/pkg/client"

	"github.com/openshift/operator-framework-olm/pkg/profiling/config"
)

// profileConfigMapLabelKey is set on every configMap this command creates;
// only configMaps carrying this label are listed when pruning old profiles.
const profileConfigMapLabelKey = "olm.openshift.io/pprof"

// pprofSecretName names the Secret whose client cert/key pair is regenerated
// by populateServingCert (presumably the Secret mounted at --cert-mount-path;
// confirm against the CronJob manifest).
const pprofSecretName = "pprof-cert"

// rootCmd is the collect-profiles command run by Execute.
var rootCmd = newCmd()

// Flag values; they are bound to rootCmd's persistent flags in init.
var (
	namespace       string // --namespace/-n: namespace of the pprof configMaps and Secret
	configMountPath string // --config-mount-path/-c: path handed to config.GetConfig
	certMountPath   string // --cert-mount-path: directory holding the client cert and key files
)

// init registers rootCmd's persistent flags and marks --namespace and
// --config-mount-path as required.
func init() {
	flags := rootCmd.PersistentFlags()
	flags.StringVarP(&namespace, "namespace", "n", "default", "The Kubernetes namespace where the generated configMaps should exist.")
	flags.StringVarP(&configMountPath, "config-mount-path", "c", "/etc/config", "The path to the collect-profiles configuration file.")
	flags.StringVar(&certMountPath, "cert-mount-path", "/var/run/secrets/serving-cert", "The path to the tls cert used by the client making https requests against the pprof URLs.")

	// These are persistent flags, so they must be marked through
	// MarkPersistentFlagRequired. MarkFlagRequired only inspects the local flag
	// set, which does not contain persistent flags until cobra merges them at
	// parse time; it returned a "no such flag" error that was silently
	// ignored, so the requirement was never enforced.
	for _, name := range []string{"namespace", "config-mount-path"} {
		if err := rootCmd.MarkPersistentFlagRequired(name); err != nil {
			// Only possible if the flag name above is misspelled: a programmer error.
			panic(fmt.Sprintf("marking flag %q as required: %v", name, err))
		}
	}
}

// main is the collect-profiles binary entry point; all work happens in Execute.
func main() {
	Execute()
}

// Execute runs rootCmd. A returned error is logged via logrus.Fatal, which
// terminates the process (exit status 1 by default).
func Execute() {
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	logrus.Fatal(err)
	// Should only be reached if logrus' exit function has been replaced.
	os.Exit(1)
}

// getTruePointer returns a pointer to a newly allocated bool holding true,
// for optional *bool fields such as ConfigMap.Immutable. Each call returns a
// distinct pointer.
func getTruePointer() *bool {
	p := new(bool)
	*p = true
	return p
}

// newCmd builds the collect-profiles root command.
//
// Each positional argument has the form "configMapName:url". A run:
//  1. lists the configMaps labelled with profileConfigMapLabelKey in the
//     target namespace and deletes all but the newest one per GenerateName,
//     aborting if any of those deletes fails;
//  2. checks that a non-empty client cert/key pair exists under
//     --cert-mount-path; if not, it writes a freshly generated pair into the
//     pprof-cert Secret and returns without collecting profiles;
//  3. GETs each URL over HTTPS, presenting that client certificate, and stores
//     the body in a new immutable configMap;
//  4. deletes the previously-newest configMap for every GenerateName that was
//     successfully replaced; and
//  5. regenerates the certificate stored in the pprof-cert Secret.
//
// Nothing is collected when no arguments are given or when the job
// configuration read from --config-mount-path has Disabled set.
func newCmd() *cobra.Command {
	var cfg config.Configuration
	return &cobra.Command{
		Use:   "collect-profiles configMapName:url",
		Short: "Retrieves the pprof data from a URL and stores it in a configMap.",
		Long: `The collect-profiles command makes https requests against pprof URLs
provided as arguments and stores that information in immutable configMaps.

# Example command with multiple arguments
./collect-profiles -n openshift-operator-lifecycle-manager \
  --config-mount-path /etc/config \
  --cert-mount-path /var/run/secrets/serving-cert \
  olm-operator-heap-:https://olm-operator-metrics:8443/debug/pprof/heap \
  catalog-operator-heap-:https://catalog-operator-metrics:8443/debug/pprof/heap
`,
		SilenceUsage: true,
		PersistentPreRunE: func(*cobra.Command, []string) error {
			return cfg.Load()
		},
		RunE: func(cmd *cobra.Command, args []string) error {
			if len(args) == 0 {
				logrus.Info("No arguments provided, exiting")
				return nil
			}

			jobConfig, err := config.GetConfig(configMountPath)
			if err != nil {
				return fmt.Errorf("error retrieving job config: %w", err)
			}

			// Exit if job is disabled
			if jobConfig.Disabled {
				logrus.Infof("CronJob disabled, exiting")
				return nil
			}

			// Validate every argument before touching the cluster so a
			// malformed one aborts the run before anything is deleted.
			validatedArguments := make([]*argument, len(args))
			for i, arg := range args {
				a, err := newArgument(arg)
				if err != nil {
					return err
				}
				validatedArguments[i] = a
			}

			// Get existing configmaps
			existingConfigMaps := &corev1.ConfigMapList{}
			if err := cfg.Client.List(cmd.Context(), existingConfigMaps, client.InNamespace(namespace), client.HasLabels{profileConfigMapLabelKey}); err != nil {
				return err
			}

			newestConfigMaps, expiredConfigMaps := separateConfigMapsIntoNewestAndExpired(existingConfigMaps.Items)

			// Attempt to delete all but the newest configMaps generated by this job.
			var errs []error
			for i := range expiredConfigMaps {
				cm := &expiredConfigMaps[i]
				if err := cfg.Client.Delete(cmd.Context(), cm); err != nil {
					errs = append(errs, err)
					continue
				}
				logrus.Infof("Successfully deleted configMap %s/%s", cm.GetNamespace(), cm.GetName())
			}

			// If a delete call failed, abort to avoid creating new configMaps
			if len(errs) != 0 {
				return fmt.Errorf("error deleting expired pprof configMaps: %v", errs)
			}

			certPath := filepath.Join(certMountPath, corev1.TLSCertKey)
			keyPath := filepath.Join(certMountPath, corev1.TLSPrivateKeyKey)

			// Without a usable client cert the scrape cannot authenticate, so
			// only (re)generate the Secret this run; presumably a later run
			// picks it up once it is mounted.
			if err := verifyCertAndKeyExist(certPath, keyPath); err != nil {
				logrus.Infof("error verifying provided cert and key: %v", err)
				logrus.Info("generating a new cert and key")
				return populateServingCert(cmd.Context(), cfg.Client)
			}

			httpClient, err := getHttpClient(certPath, keyPath)
			if err != nil {
				return err
			}

			// Track successfully created configMaps by generateName for each endpoint being scraped.
			createdCM := map[string]struct{}{}

			for _, a := range validatedArguments {
				b, err := requestURLBody(httpClient, a.url)
				if err != nil {
					logrus.Infof("error retrieving pprof profile: %v", err)
					continue
				}

				cm := &corev1.ConfigMap{
					ObjectMeta: metav1.ObjectMeta{
						GenerateName: a.generateName,
						Namespace:    namespace,
						Labels: map[string]string{
							profileConfigMapLabelKey: "",
						},
					},
					Immutable: getTruePointer(),
					BinaryData: map[string][]byte{
						"profile.pb.gz": b,
					},
				}

				if err := cfg.Client.Create(cmd.Context(), cm); err != nil {
					// Only GenerateName was set on cm, so report that prefix;
					// cm.GetName() may still be empty after a failed create.
					logrus.Errorf("error creating configMap with generateName %s/%s: %v", namespace, a.generateName, err)
					continue
				}

				logrus.Infof("Successfully created configMap %s/%s", cm.GetNamespace(), cm.GetName())
				createdCM[a.generateName] = struct{}{}
			}

			// Delete the configMaps which are no longer the newest
			for i := range newestConfigMaps {
				cm := &newestConfigMaps[i]
				// Don't delete ConfigMaps that were not replaced.
				// Also prevents deletes of configMaps with generateNames not included in command.
				if _, ok := createdCM[cm.GenerateName]; !ok {
					continue
				}
				if err := cfg.Client.Delete(cmd.Context(), cm); err != nil {
					errs = append(errs, err)
					continue
				}
				logrus.Infof("Successfully deleted configMap %s/%s", cm.GetNamespace(), cm.GetName())
			}

			if len(errs) != 0 {
				return fmt.Errorf("error deleting existing pprof configMaps: %v", errs)
			}

			// Update serving cert after a successful run
			return populateServingCert(cmd.Context(), cfg.Client)
		},
	}
}

// verifyCertAndKeyExist returns nil when both certPath and keyPath name
// existing, non-empty files. The cert is checked before the key. A missing
// file yields the os.Stat error unchanged; an empty file yields an error
// naming which file was empty.
func verifyCertAndKeyExist(certPath, keyPath string) error {
	for _, f := range []struct {
		path       string
		emptyError string
	}{
		{path: certPath, emptyError: "cert file should not be empty"},
		{path: keyPath, emptyError: "key file should not be empty"},
	} {
		info, err := os.Stat(f.path)
		switch {
		case err != nil:
			return err
		case info.Size() == 0:
			return fmt.Errorf("%s", f.emptyError)
		}
	}
	return nil
}

// separateConfigMapsIntoNewestAndExpired groups configMaps by GenerateName and
// returns, for each group, the one with the latest CreationTimestamp in
// newestCMs; every other configMap is returned in expiredCMs. On equal
// timestamps the configMap seen first is kept as newest. The order of
// newestCMs is unspecified (it comes from map iteration).
func separateConfigMapsIntoNewestAndExpired(configMaps []corev1.ConfigMap) (newestCMs []corev1.ConfigMap, expiredCMs []corev1.ConfigMap) {
	newestByGenerateName := map[string]corev1.ConfigMap{}
	for _, candidate := range configMaps {
		key := candidate.GenerateName
		current, seen := newestByGenerateName[key]
		switch {
		case !seen:
			newestByGenerateName[key] = candidate
		case candidate.CreationTimestamp.After(current.CreationTimestamp.Time):
			// candidate supersedes the previous newest, which becomes expired.
			newestByGenerateName[key] = candidate
			expiredCMs = append(expiredCMs, current)
		default:
			expiredCMs = append(expiredCMs, candidate)
		}
	}

	for _, cm := range newestByGenerateName {
		newestCMs = append(newestCMs, cm)
	}
	return newestCMs, expiredCMs
}

// argument is a validated "configMapName:url" command-line argument.
type argument struct {
	generateName string   // GenerateName prefix for the configMaps created from url
	url          *url.URL // https pprof endpoint to request
}

// newArgument parses and validates a single "configMapName:url" argument.
//
// The string is split on the first ':' only, so the URL part may itself
// contain colons (scheme separator, port). The configMap name must be
// non-empty, and the URL must use the https scheme (case-insensitive) and
// include a host.
func newArgument(s string) (*argument, error) {
	splitStrings := strings.SplitN(s, ":", 2)
	if len(splitStrings) != 2 {
		return nil, fmt.Errorf("%s is an invalid argument, should match configMapName:url", s)
	}
	generateName, rawURL := splitStrings[0], splitStrings[1]

	// An empty name would leave the created configMap with neither a Name nor
	// a GenerateName, and would group it with any other labelled configMap
	// that lacks a GenerateName when old profiles are pruned.
	if generateName == "" {
		return nil, fmt.Errorf("%s is an invalid argument, configMapName must not be empty", s)
	}

	u, err := url.Parse(rawURL)
	if err != nil {
		return nil, fmt.Errorf("%s is an invalid argument: %w", s, err)
	}

	if strings.ToLower(u.Scheme) != "https" {
		return nil, fmt.Errorf("URL Scheme must be HTTPS")
	}

	// Forms such as "https:host" or "https:///path" parse successfully but
	// have no host, so they could never be requested.
	if u.Host == "" {
		return nil, fmt.Errorf("%s is an invalid argument, URL must include a host", s)
	}

	return &argument{generateName: generateName, url: u}, nil
}

// getHttpClient returns an HTTP client that presents the key pair loaded from
// certPath and keyPath as its TLS client certificate. Any error from loading
// the pair is returned unchanged.
//
// The server's certificate is NOT verified (InsecureSkipVerify), so the client
// authenticates itself to the pprof endpoints but does not authenticate them.
func getHttpClient(certPath, keyPath string) (*http.Client, error) {
	clientCert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return nil, err
	}

	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{clientCert},
		// NOTE(review): presumably disabled because no CA bundle for the
		// endpoints' serving certs is available here; confirm.
		InsecureSkipVerify: true,
	}
	transport := &http.Transport{TLSClientConfig: tlsConfig}
	return &http.Client{Transport: transport}, nil
}

// requestURLBody sends a GET for u using httpClient and returns the complete
// response body. Any status other than 200 OK is reported as an error.
//
// The response body is always closed, on success and on every error path.
// Previously it was never closed, which leaked the underlying connection for
// every scraped URL.
func requestURLBody(httpClient *http.Client, u *url.URL) ([]byte, error) {
	response, err := httpClient.Do(&http.Request{
		Method: http.MethodGet,
		URL:    u,
	})
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("%s responded with %d status code instead of %d", u, response.StatusCode, http.StatusOK)
	}

	var b bytes.Buffer
	if _, err := io.Copy(&b, response.Body); err != nil {
		return nil, fmt.Errorf("error reading response body: %w", err)
	}

	return b.Bytes(), nil
}

// populateServingCert fetches the pprofSecretName Secret from the namespace
// given by --namespace, replaces its corev1.TLSCertKey and
// corev1.TLSPrivateKeyKey entries with a freshly generated pair from
// getCertAndKey, and updates the Secret. The Secret must already exist;
// failures from Get, key generation, or Update are returned.
func populateServingCert(ctx context.Context, kubeClient client.Client) error {
	secret := &corev1.Secret{}
	key := types.NamespacedName{Namespace: namespace, Name: pprofSecretName}
	if err := kubeClient.Get(ctx, key, secret); err != nil {
		return fmt.Errorf("error retrieving secret %s: %w", key, err)
	}

	cert, privateKey, err := getCertAndKey()
	if err != nil {
		return fmt.Errorf("error generating cert and key: %w", err)
	}

	// A Secret with no data has a nil Data map; assigning into it would panic.
	if secret.Data == nil {
		secret.Data = make(map[string][]byte, 2)
	}
	secret.Data[corev1.TLSCertKey] = cert
	secret.Data[corev1.TLSPrivateKeyKey] = privateKey
	return kubeClient.Update(ctx, secret)
}

func getCertAndKey() ([]byte, []byte, error) {
	cert := &x509.Certificate{
		SerialNumber: big.NewInt(1658),
		Subject: pkix.Name{
			Organization: []string{"Red Hat, Inc."},
		},
		NotBefore:   time.Now(),
		NotAfter:    time.Now().Add(time.Hour),
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
		KeyUsage:    x509.KeyUsageDigitalSignature,
	}

	caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		return nil, nil, err
	}

	caBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, &caPrivKey.PublicKey, caPrivKey)
	if err != nil {
		return nil, nil, err
	}

	caPEM := new(bytes.Buffer)
	pem.Encode(caPEM, &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: caBytes,
	})

	caPrivKeyPEM := new(bytes.Buffer)
	pem.Encode(caPrivKeyPEM, &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey),
	})

	return caPEM.Bytes(), caPrivKeyPEM.Bytes(), nil
}
